/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include <memory>
#include "execute_graph_builder.h"
#include "execute_graph.h"
#include "common_reg.h"
namespace ge {

namespace {
/**
 * Appends a new node to @p gb and names it, its op definition, and its ports.
 *
 * @param gb      graph builder that receives the node.
 * @param name    node name passed to SetName().
 * @param op_def  op-definition offset passed to SetOpDef().
 * @param inputs  input port names; the n-th name is assigned to input index n.
 * @param outputs output port names; the n-th name is assigned to output index n.
 * @return the offset returned by gb.AddNode() for the new node.
 */
PoolOffset AddNode(ExecuteGraphBuilder &gb, const char *name, PoolOffset op_def,
                   std::initializer_list<std::string> inputs, std::initializer_list<std::string> outputs) {
  auto added = gb.AddNode();
  gb.GetNodeBuilder(added)->SetName(name).SetOpDef(op_def);

  // Inputs are named positionally, in the order the caller listed them.
  int input_index = 0;
  for (const std::string &input_name : inputs) {
    gb.GetNodeBuilder(added)->SetInputName(input_index, input_name.c_str());
    ++input_index;
  }

  // Outputs likewise, with their own independent counter.
  int output_index = 0;
  for (const std::string &output_name : outputs) {
    gb.GetNodeBuilder(added)->SetOutputName(output_index, output_name.c_str());
    ++output_index;
  }

  return added;
}
}  // namespace

/**
 *    20210814034028_online_Node_Output
 *                    |
 *                  MatMul
 *                  /   \
 * trans_TransData_0    trans_TransData_1
 *        |                    |
 *   MatMul_in_0           MatMul_in_1
 */
std::unique_ptr<ExecuteGraph> BuildBaselineGraph1() {
  ExecuteGraphBuilder gb;
  auto reg = RegOpDefs(gb);

  auto data0 = AddNode(gb, "MatMul_in_0", reg.data, {"x"}, {"y"});
  auto data1 = AddNode(gb, "MatMul_in_1", reg.data, {"x"}, {"y"});
  auto trans0 = AddNode(gb, "trans_TransData_0", reg.transdata, {"src"}, {"dst"});
  auto trans1 = AddNode(gb, "trans_TransData_1", reg.transdata, {"src"}, {"dst"});
  auto matmul = AddNode(gb, "MatMul", reg.matmul, {"x1", "x2"}, {"y"});
  auto netoutput = AddNode(gb, "20210814034028_online_Node_Output", reg.netoutput, {"x1"}, {"y1"});

  gb.GetNodeBuilder(trans0)->SetAttr(0, std::string("ND")).SetAttr(1, std::string("FRACTAL_NZ"));
  gb.GetNodeBuilder(trans1)->SetAttr(0, std::string("ND")).SetAttr(1, std::string("FRACTAL_NZ"));

  gb.AddEdge(data0, 0, trans0, 0)
  .AddEdge(data1, 0, trans1, 0)
  .AddEdge(trans0, 0, matmul, 0)
  .AddEdge(trans1, 0, matmul, 1)
  .AddEdge(matmul, 0, netoutput, 0);

  auto graph = gb.Build();

  auto data0_desc = graph->GetNodeByIndex(0)->GetOpDesc();
  data0_desc->MutableOutputDesc(0)->SetOriginFormat(FORMAT_ND);
  data0_desc->MutableOutputDesc(0)->SetFormat(FORMAT_ND);
  data0_desc->MutableOutputDesc(0)->SetOriginDataType(DT_FLOAT16);
  data0_desc->MutableOutputDesc(0)->SetDataType(DT_FLOAT16);
  data0_desc->MutableOutputDesc(0)->SetOriginShape(Shape({256, 256}));
  data0_desc->MutableOutputDesc(0)->SetShape(Shape({256, 256}));

  auto data1_desc = graph->GetNodeByIndex(1)->GetOpDesc();
  data1_desc->MutableOutputDesc(0)->SetOriginFormat(FORMAT_ND);
  data1_desc->MutableOutputDesc(0)->SetFormat(FORMAT_ND);
  data1_desc->MutableOutputDesc(0)->SetOriginDataType(DT_FLOAT16);
  data1_desc->MutableOutputDesc(0)->SetDataType(DT_FLOAT16);
  data1_desc->MutableOutputDesc(0)->SetOriginShape(Shape({256, 256}));
  data1_desc->MutableOutputDesc(0)->SetShape(Shape({256, 256}));

  auto trans0_desc = graph->GetNodeByIndex(2)->GetOpDesc();
  trans0_desc->MutableOutputDesc(0)->SetOriginFormat(FORMAT_ND);
  trans0_desc->MutableOutputDesc(0)->SetFormat(FORMAT_FRACTAL_NZ);
  trans0_desc->MutableOutputDesc(0)->SetOriginDataType(DT_FLOAT16);
  trans0_desc->MutableOutputDesc(0)->SetDataType(DT_FLOAT16);

  auto trans1_desc = graph->GetNodeByIndex(3)->GetOpDesc();
  trans1_desc->MutableOutputDesc(0)->SetOriginFormat(FORMAT_ND);
  trans1_desc->MutableOutputDesc(0)->SetFormat(FORMAT_FRACTAL_NZ);
  trans1_desc->MutableOutputDesc(0)->SetOriginDataType(DT_FLOAT16);
  trans1_desc->MutableOutputDesc(0)->SetDataType(DT_FLOAT16);

  auto mm_desc = graph->GetNodeByIndex(4)->GetOpDesc();
  mm_desc->MutableOutputDesc(0)->SetOriginFormat(FORMAT_ND);
  mm_desc->MutableOutputDesc(0)->SetFormat(FORMAT_FRACTAL_NZ);
  mm_desc->MutableOutputDesc(0)->SetOriginDataType(DT_FLOAT16);
  mm_desc->MutableOutputDesc(0)->SetDataType(DT_FLOAT16);

  return std::move(graph);
}
}